Work carried out by:
alvaro.salinase@gmail.com - student ID: 201073001-8
martin.villanueva@alumnos.usm.cl - student ID: 201104012-0
DI UTFSM. September 2016.
import numpy as np
import matplotlib.pyplot as plt
# sklearn utilities
from sklearn.preprocessing import StandardScaler
from sklearn.cross_validation import KFold
from sklearn.cross_validation import train_test_split
# keras functionalities
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import SGD
The function xor_generator(m) shown below generates $m$ random points with a distribution that mimics the behaviour of XOR. The strategy is to draw random points in the region $[-1,1] \times [-1,1]$, labelling the points in quadrants I and III with $0$ and the points in quadrants II and IV with $1$.
Additionally, within each quadrant a linear transformation is applied to push the points away from the axes. Otherwise, points of different classes would lie very close to each other (along an axis) and the problem would become considerably harder.
Remark: this distribution of points is said to simulate XOR because it reproduces the behaviour of the exclusive-or logical operator: when the coordinates $(x,y)$ of a point have different signs the point belongs to class $1$, and when they have the same sign it belongs to class $0$.
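As a quick illustration of the transformation just described, here is a minimal sketch (not part of the original notebook): positive coordinates are mapped into (0.1, 0.9] and negative ones into [-0.9, -0.1), so a band of width 0.2 around each axis is left empty.
# sketch: effect of the affine map used in xor_generator on a few sample values
import numpy as np
v = np.array([0.001, 0.5, 1.0, -0.001, -0.5, -1.0])
pos = v > 0.
v[pos] = v[pos]*0.8 + 0.1     # e.g. 0.001 -> 0.1008, 1.0 -> 0.9
v[~pos] = v[~pos]*0.8 - 0.1   # e.g. -0.001 -> -0.1008, -1.0 -> -0.9
print(v)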
# for reproducibility of experiments
np.random.seed(1)
def xor_generator(m):
    """
    m: number of points to generate
    """
    # generate m random points on [-1,1]x[-1,1]
    X = 2.*np.random.random(m)-1.
    Y = 2.*np.random.random(m)-1.
    # arrangement to cluster the points
    x_mask = X>0.
    y_mask = Y>0.
    X[x_mask] *= 0.8; X[x_mask] += 0.1
    X[~x_mask] *= 0.8; X[~x_mask] -= 0.1
    Y[y_mask] *= 0.8; Y[y_mask] += 0.1
    Y[~y_mask] *= 0.8; Y[~y_mask] -= 0.1
    XY = np.vstack([X,Y]).T
    # generating the labels
    y = np.zeros(m, dtype=np.int)
    mask = np.multiply.reduce(XY, axis=1) > 0.
    y[mask] = 1
    return (XY, y)
def xor_plot(X,y):
    # scatter plot of the two classes (class 0 in blue, class 1 in red)
    X0 = X[y==0]
    X1 = X[y==1]
    plt.figure(figsize=(7,7))
    plt.scatter(X0[:,0], X0[:,1], c='b')
    plt.scatter(X1[:,0], X1[:,1], c='r')
    plt.plot((-1.1,1.1),(0,0),'k--')
    plt.plot((0,0),(-1.1,1.1),'k--')
    plt.xlim([-1.1,1.1])
    plt.ylim([-1.1,1.1])
    plt.show()
Below, 2000 random points are generated and their distribution is plotted.
X,y = xor_generator(2000)
xor_plot(X,y)
Next, the training, validation and test sets are built. Of the 2000 generated points, 1000 are used for training and 1000 for testing; of the 1000 training points, 20% is held out for validation during network training.
Note: since the data were generated at random, a further random partition is unnecessary and the sets are split sequentially.
# training data
X_tr = X[0:800]
y_tr = y[0:800]
# validation data
X_val = X[800:1000]
y_val = y[800:1000]
# testing data
X_ts = X[1000::]
y_ts = y[1000::]
# training data
xor_plot(X_tr, y_tr)
In this first part a network consisting of a single neuron was built, with the following characteristics:
- A single output unit with sigmoid activation; keras uses this value to determine the corresponding class.
- Optimizer: ADAM, an efficient and robust stochastic algorithm based on low-order moment estimates. It was chosen to ensure convergence.
# building the model
model0 = Sequential()
model0.add(Dense(output_dim=1, input_dim=2, activation='sigmoid', init='normal'))
# compiling the model
model0.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fitting the model
hist0 = model0.fit(X_tr, y_tr, nb_epoch=300, verbose=0, validation_data=(X_val, y_val))
The performance obtained on the test set is shown below:
scores = model0.evaluate(X_ts, y_ts, verbose=0)
print("\n{0}: {1}%".format('Accuracy', scores[1]*100))
Accuracy: 54.8%
This result indicates that a single neuron reaches an accuracy of roughly 50%, i.e. it correctly classifies only about half of the data (similar to a random classifier).
This confirms experimentally that a single neuron cannot solve the non-linear XOR problem. It is a well-known result: even though the neuron has a non-linear activation, its decision function is still linear ($w^T x = 0$), so it cannot correctly classify data that are not linearly separable.
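To make this concrete, the line $w^T x + b = 0$ learned by model0 can be drawn over the test data. The following is a small sketch (not part of the original assignment) that reuses the objects already defined above.
# sketch: recover the weights of the single neuron and plot its linear decision boundary
W, b = model0.layers[0].get_weights()     # W: shape (2, 1), b: shape (1,)
w1, w2 = W[:, 0]
xs = np.linspace(-1.1, 1.1, 100)
ys = -(w1*xs + b[0]) / w2                 # points where the sigmoid input is zero (p = 0.5); assumes w2 != 0
plt.figure(figsize=(7,7))
plt.scatter(X_ts[y_ts==0][:,0], X_ts[y_ts==0][:,1], c='b')
plt.scatter(X_ts[y_ts==1][:,0], X_ts[y_ts==1][:,1], c='r')
plt.plot(xs, ys, 'k-')
plt.xlim([-1.1,1.1]); plt.ylim([-1.1,1.1])
plt.show()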
For this part almost the same configuration as above is kept, with a few small adjustments:
- A hidden layer of two neurons is added, with activation function tanh.
# building the model
model1 = Sequential()
model1.add(Dense(output_dim=2, input_dim=2, activation='tanh', init='normal'))
model1.add(Dense(output_dim=1, init='normal', activation='sigmoid'))
# compiling the model
model1.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fitting the model
hist1 = model1.fit(X_tr, y_tr, nb_epoch=300, verbose=0, validation_data=(X_val, y_val))
scores = model1.evaluate(X_ts, y_ts, verbose=0)
print("\n{0}: {1}%".format('Accuracy', scores[1]*100))
Accuracy: 91.4%
As can be seen, the accuracy on the test set increased markedly. This shows experimentally that by merely adding two neurons in a hidden layer, the network becomes able to learn non-linear decision boundaries.
Next, the same model is trained, but with 10 neurons in the hidden layer.
# building the model
model2 = Sequential()
model2.add(Dense(output_dim=10, input_dim=2, activation='tanh', init='normal'))
model2.add(Dense(output_dim=1, init='normal', activation='sigmoid'))
# compiling the model
model2.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fitting the model
hist2 = model2.fit(X_tr, y_tr, nb_epoch=300, verbose=0, validation_data=(X_val, y_val))
scores = model2.evaluate(X_ts, y_ts, verbose=0)
print("\n{0}: {1}%".format('Accuracy', scores[1]*100))
Accuracy: 100.0%
The result is a perfect fit on the test set. This is in line with the universal approximation theorem: with a 3-layer network and a sufficient number of neurons in the hidden layer, it is possible to approximate any distribution/function.
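To visualise this fit, the following small sketch (not in the original notebook) evaluates model2 on a regular grid and shades the predicted class of each grid point, which reveals the XOR-shaped decision regions.
# sketch: predicted class of model2 over a grid covering [-1.1, 1.1] x [-1.1, 1.1]
xx, yy = np.meshgrid(np.linspace(-1.1, 1.1, 200), np.linspace(-1.1, 1.1, 200))
grid = np.c_[xx.ravel(), yy.ravel()]
probs = model2.predict(grid).reshape(xx.shape)   # sigmoid outputs in [0, 1]
plt.figure(figsize=(7,7))
plt.contourf(xx, yy, (probs > 0.5).astype(float), alpha=0.3, cmap='bwr')
plt.scatter(X_ts[:,0], X_ts[:,1], c=y_ts, cmap='bwr', s=10)
plt.xlim([-1.1,1.1]); plt.ylim([-1.1,1.1])
plt.show()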
"""
Input:
> hist - history object from model.fit() method
Output:
> plot of training and validation loss vs epochs
"""
def history_plot(hist, title=None):
fig = plt.figure(figsize=(16,8))
ax = fig.gca()
ax.set_xticks(np.linspace(1,301,20))
ax.set_yticks(np.linspace(0,1,10))
plt.xlim(0,301)
plt.ylim(0,1)
if title is None:
plt.title('Mean Squared Training and Validation Error')
else: plt.title(title)
plt.plot(range(1,301), hist.history['loss'], 'bo-', label='MSE train')
plt.plot(range(1,301), hist.history['val_loss'], 'go-', label='MSE validation')
plt.legend(loc=1)
plt.xlabel('Number of Epochs')
plt.ylabel('MSE')
plt.grid()
plt.show()
The pd.read_csv call in the following code downloads the corresponding dataset, reads it (comma-separated values) and loads it into a pandas data frame with the appropriate label for each feature.
The train_test_split call then generates the training and test sets: the data frame is split into 75% for training and 25% for testing, producing the respective training and test data frames.
import pandas as pd
url = 'http://mldata.org/repository/data/download/csv/regression-datasets-housing/'
df = pd.read_csv(url, sep=',', header=None, names=['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX',
'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV'])
from sklearn.cross_validation import train_test_split
df_train, df_test = train_test_split(df, test_size=0.25, random_state=0)
As can be seen in the cells that follow, the dataset has the following characteristics:
- 506 examples described by 14 features.
- Some features take only a few distinct values: ZN, CHAS (categorical), RAD, TAX and PTRATIO.
# showing some characteristics of the dataset contained in the frame
print('Number of examples: {0}'.format(df.shape[0]))
print('Number of features: {0}'.format(df.shape[1]))
Number of examples: 506
Number of features: 14
df.info()
[output omitted: 506 rows x 14 columns]
df.describe()
[output omitted: 506 rows x 14 columns]
Although in regression problems the models obtained after scaling the data are all equivalent, up to changes in the size of the weights associated with each feature, in the particular case of ANNs there are numerical and practical considerations for which standardization/normalization helps a great deal.
Note: to keep results consistent, the same transformation (fitted on the training data) is applied to both the training and the test data.
# scaling for training set
scaler = StandardScaler().fit(df_train)
df_train_scaled = pd.DataFrame(scaler.transform(df_train), columns=df_train.columns)
y_train_scaled = df_train_scaled.pop('MEDV')
X_train_scaled = df_train_scaled
# the same but for testing set
df_test_scaled = pd.DataFrame(scaler.transform(df_test), columns=df_test.columns)
y_test_scaled = df_test_scaled.pop('MEDV')
X_test_scaled = df_test_scaled
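Note that MEDV was standardized together with the features, so the network's predictions are in standardized units. The following is a small sketch (an addition, not part of the original notebook) of how they could be mapped back to the original MEDV scale using the fitted scaler:
# sketch: undo the standardization of the target (MEDV is the last column of the frame)
medv_mean = scaler.mean_[-1]
medv_std = scaler.scale_[-1]
def to_medv_units(y_scaled):
    # map standardized predictions back to the original MEDV units
    return y_scaled * medv_std + medv_mean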
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
# compiling the model
model.compile(optimizer=SGD(lr=0.01), loss='mean_squared_error')
# fitting the model
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300,
verbose=1, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
Train on 379 samples, validate on 127 samples
Epoch 1/300 379/379 [==============================] - 0s - loss: 1.0437 - val_loss: 0.9331
Epoch 2/300 379/379 [==============================] - 0s - loss: 0.9628 - val_loss: 0.9264
...
Epoch 299/300 379/379 [==============================] - 0s - loss: 0.2429 - val_loss: 0.3585
Epoch 300/300 379/379 [==============================] - 0s - loss: 0.2377 - val_loss: 0.3584
First of all, it is important to mention that with learning rate lr=0.2 (as the assignment requested) the network weights diverge within the first epoch. The value was therefore changed experimentally to lr=0.01, for which training does converge.
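This divergence can be checked with a short experiment; the following is a sketch (not part of the original notebook) that rebuilds the same architecture, compiles it with SGD(lr=0.2) and trains it for a few epochs, so the behaviour can be inspected in the recorded losses.
# sketch: same architecture as below, trained briefly with the originally requested lr=0.2
model_div = Sequential()
model_div.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model_div.add(Activation('sigmoid'))
model_div.add(Dense(1, init='uniform'))
model_div.add(Activation('linear'))
model_div.compile(optimizer=SGD(lr=0.2), loss='mean_squared_error')
hist_div = model_div.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(),
                         nb_epoch=5, verbose=0)
print(hist_div.history['loss'])   # the reported divergence shows up as rapidly growing losses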
The results obtained with lr=0.01 are shown below:
history_plot(hist)
As expected, as the number of epochs increases both the training and the validation error decrease, with the training error always below the validation error, since the network tends to overfit the data it is trained on.
Note that beyond roughly 100 epochs there is no substantial improvement in either error.
The procedure above is now repeated (with the same lr=0.01), but using the relu activation function in the hidden layer.
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('relu'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
# compiling the model
model.compile(optimizer=SGD(lr=0.01), loss='mean_squared_error')
# fitting the model
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300,
verbose=1, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
Train on 379 samples, validate on 127 samples Epoch 1/300 379/379 [==============================] - 0s - loss: 0.9418 - val_loss: 0.8649 Epoch 2/300 379/379 [==============================] - 0s - loss: 0.8247 - val_loss: 0.7844 Epoch 3/300 379/379 [==============================] - 0s - loss: 0.7150 - val_loss: 0.7110 Epoch 4/300 379/379 [==============================] - 0s - loss: 0.6121 - val_loss: 0.6476 Epoch 5/300 379/379 [==============================] - 0s - loss: 0.5209 - val_loss: 0.5976 Epoch 6/300 379/379 [==============================] - 0s - loss: 0.4480 - val_loss: 0.5604 Epoch 7/300 379/379 [==============================] - 0s - loss: 0.3909 - val_loss: 0.5312 Epoch 8/300 379/379 [==============================] - 0s - loss: 0.3478 - val_loss: 0.5070 Epoch 9/300 379/379 [==============================] - 0s - loss: 0.3141 - val_loss: 0.4857 Epoch 10/300 379/379 [==============================] - 0s - loss: 0.2869 - val_loss: 0.4669 Epoch 11/300 379/379 [==============================] - 0s - loss: 0.2659 - val_loss: 0.4488 Epoch 12/300 379/379 [==============================] - 0s - loss: 0.2481 - val_loss: 0.4329 Epoch 13/300 379/379 [==============================] - 0s - loss: 0.2331 - val_loss: 0.4183 Epoch 14/300 379/379 [==============================] - 0s - loss: 0.2216 - val_loss: 0.4060 Epoch 15/300 379/379 [==============================] - 0s - loss: 0.2121 - val_loss: 0.3960 Epoch 16/300 379/379 [==============================] - 0s - loss: 0.2038 - val_loss: 0.3850 Epoch 17/300 379/379 [==============================] - 0s - loss: 0.1974 - val_loss: 0.3765 Epoch 18/300 379/379 [==============================] - 0s - loss: 0.1917 - val_loss: 0.3684 Epoch 19/300 379/379 [==============================] - 0s - loss: 0.1873 - val_loss: 0.3624 Epoch 20/300 379/379 [==============================] - 0s - loss: 0.1834 - val_loss: 0.3565 Epoch 21/300 379/379 [==============================] - 0s - loss: 0.1800 - val_loss: 0.3508 Epoch 22/300 379/379 [==============================] - 0s - loss: 0.1770 - val_loss: 0.3461 Epoch 23/300 379/379 [==============================] - 0s - loss: 0.1741 - val_loss: 0.3416 Epoch 24/300 379/379 [==============================] - 0s - loss: 0.1722 - val_loss: 0.3380 Epoch 25/300 379/379 [==============================] - 0s - loss: 0.1695 - val_loss: 0.3344 Epoch 26/300 379/379 [==============================] - 0s - loss: 0.1671 - val_loss: 0.3293 Epoch 27/300 379/379 [==============================] - 0s - loss: 0.1653 - val_loss: 0.3268 Epoch 28/300 379/379 [==============================] - 0s - loss: 0.1635 - val_loss: 0.3251 Epoch 29/300 379/379 [==============================] - 0s - loss: 0.1619 - val_loss: 0.3222 Epoch 30/300 379/379 [==============================] - 0s - loss: 0.1604 - val_loss: 0.3199 Epoch 31/300 379/379 [==============================] - 0s - loss: 0.1586 - val_loss: 0.3154 Epoch 32/300 379/379 [==============================] - 0s - loss: 0.1574 - val_loss: 0.3141 Epoch 33/300 379/379 [==============================] - 0s - loss: 0.1561 - val_loss: 0.3123 Epoch 34/300 379/379 [==============================] - 0s - loss: 0.1544 - val_loss: 0.3098 Epoch 35/300 379/379 [==============================] - 0s - loss: 0.1535 - val_loss: 0.3077 Epoch 36/300 379/379 [==============================] - 0s - loss: 0.1523 - val_loss: 0.3059 Epoch 37/300 379/379 [==============================] - 0s - loss: 0.1518 - val_loss: 0.3061 Epoch 38/300 379/379 [==============================] - 0s - loss: 0.1501 - 
val_loss: 0.3025 Epoch 39/300 379/379 [==============================] - 0s - loss: 0.1490 - val_loss: 0.3014 Epoch 40/300 379/379 [==============================] - 0s - loss: 0.1479 - val_loss: 0.3013 Epoch 41/300 379/379 [==============================] - 0s - loss: 0.1471 - val_loss: 0.2985 Epoch 42/300 379/379 [==============================] - 0s - loss: 0.1460 - val_loss: 0.2960 Epoch 43/300 379/379 [==============================] - 0s - loss: 0.1448 - val_loss: 0.2938 Epoch 44/300 379/379 [==============================] - 0s - loss: 0.1441 - val_loss: 0.2940 Epoch 45/300 379/379 [==============================] - 0s - loss: 0.1432 - val_loss: 0.2928 Epoch 46/300 379/379 [==============================] - 0s - loss: 0.1423 - val_loss: 0.2920 Epoch 47/300 379/379 [==============================] - 0s - loss: 0.1413 - val_loss: 0.2895 Epoch 48/300 379/379 [==============================] - 0s - loss: 0.1405 - val_loss: 0.2875 Epoch 49/300 379/379 [==============================] - 0s - loss: 0.1396 - val_loss: 0.2878 Epoch 50/300 379/379 [==============================] - 0s - loss: 0.1391 - val_loss: 0.2861 Epoch 51/300 379/379 [==============================] - 0s - loss: 0.1380 - val_loss: 0.2859 Epoch 52/300 379/379 [==============================] - 0s - loss: 0.1372 - val_loss: 0.2853 Epoch 53/300 379/379 [==============================] - 0s - loss: 0.1360 - val_loss: 0.2847 Epoch 54/300 379/379 [==============================] - 0s - loss: 0.1359 - val_loss: 0.2835 Epoch 55/300 379/379 [==============================] - 0s - loss: 0.1347 - val_loss: 0.2823 Epoch 56/300 379/379 [==============================] - 0s - loss: 0.1346 - val_loss: 0.2818 Epoch 57/300 379/379 [==============================] - 0s - loss: 0.1338 - val_loss: 0.2803 Epoch 58/300 379/379 [==============================] - 0s - loss: 0.1328 - val_loss: 0.2785 Epoch 59/300 379/379 [==============================] - 0s - loss: 0.1323 - val_loss: 0.2773 Epoch 60/300 379/379 [==============================] - 0s - loss: 0.1315 - val_loss: 0.2769 Epoch 61/300 379/379 [==============================] - 0s - loss: 0.1306 - val_loss: 0.2753 Epoch 62/300 379/379 [==============================] - 0s - loss: 0.1300 - val_loss: 0.2751 Epoch 63/300 379/379 [==============================] - 0s - loss: 0.1296 - val_loss: 0.2726 Epoch 64/300 379/379 [==============================] - 0s - loss: 0.1289 - val_loss: 0.2729 Epoch 65/300 379/379 [==============================] - 0s - loss: 0.1281 - val_loss: 0.2721 Epoch 66/300 379/379 [==============================] - 0s - loss: 0.1275 - val_loss: 0.2725 Epoch 67/300 379/379 [==============================] - 0s - loss: 0.1270 - val_loss: 0.2713 Epoch 68/300 379/379 [==============================] - 0s - loss: 0.1267 - val_loss: 0.2698 Epoch 69/300 379/379 [==============================] - 0s - loss: 0.1258 - val_loss: 0.2687 Epoch 70/300 379/379 [==============================] - 0s - loss: 0.1254 - val_loss: 0.2681 Epoch 71/300 379/379 [==============================] - 0s - loss: 0.1250 - val_loss: 0.2668 Epoch 72/300 379/379 [==============================] - 0s - loss: 0.1237 - val_loss: 0.2668 Epoch 73/300 379/379 [==============================] - 0s - loss: 0.1237 - val_loss: 0.2660 Epoch 74/300 379/379 [==============================] - 0s - loss: 0.1228 - val_loss: 0.2651 Epoch 75/300 379/379 [==============================] - 0s - loss: 0.1223 - val_loss: 0.2659 Epoch 76/300 379/379 [==============================] - 0s - loss: 0.1215 - val_loss: 0.2647 
Epoch 77/300 379/379 [==============================] - 0s - loss: 0.1208 - val_loss: 0.2644
...
Epoch 299/300 379/379 [==============================] - 0s - loss: 0.0693 - val_loss: 0.1996
Epoch 300/300 379/379 [==============================] - 0s - loss: 0.0683 - val_loss: 0.2022
(verbose training log truncated: both losses decrease steadily from epoch 77 through epoch 300)
history_plot(hist)
The result above shows a substantial improvement over what was obtained with sigmoid activation functions. The following table summarizes the final training and validation losses for both cases.
| Final loss | sigmoid | relu |
|---|---|---|
| Training | 0.2377 | 0.0683 |
| Validation | 0.3584 | 0.2022 |
Additionally, as the plots show, convergence is much smoother in the relu case: the loss follows a clear, almost monotonically decreasing trend.
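As a quick check, the final losses reported in the table can be read directly from the `History` objects returned by `fit`. A minimal sketch, where `hist_sigmoid` and `hist_relu` are hypothetical names for the histories of the two runs:
# minimal sketch: final training/validation loss of each run
for name, h in [('sigmoid', hist_sigmoid), ('relu', hist_relu)]:
    final_tr = h.history['loss'][-1]       # training loss at the last epoch
    final_val = h.history['val_loss'][-1]  # validation loss at the last epoch
    print('{0}: training={1:.4f}, validation={2:.4f}'.format(name, final_tr, final_val))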
n_lr = 20
learning_rate = np.linspace(0, 0.03, n_lr+1)[1::]
hist_list = list()
for i in range(n_lr):
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
model.compile(optimizer=SGD(lr=learning_rate[i]), loss='mean_squared_error')
# training the network
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300,
verbose=0, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
# storing the results
hist_list.append(hist)
# making training and validation error plots
for i,hist in enumerate(hist_list):
title = 'Mean Squared Training and Validation Error. Learning rate: {0}'.format(learning_rate[i])
history_plot(hist, title)
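To complement the per-rate plots, the final validation error for each tested learning rate can also be summarized numerically. A short sketch, assuming `learning_rate` and `hist_list` as defined above:
# sketch: final validation MSE per learning rate, and the best-performing rate
for lr, h in zip(learning_rate, hist_list):
    print('lr={0:.4f}: final val MSE = {1:.4f}'.format(lr, h.history['val_loss'][-1]))
best_i = np.argmin([h.history['val_loss'][-1] for h in hist_list])
print('best learning rate: {0:.4f}'.format(learning_rate[best_i]))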
Xm = X_train_scaled.as_matrix()
ym = y_train_scaled.as_matrix()
kfold = KFold(len(Xm), 5)
cvscores = []
for i, (train, val) in enumerate(kfold):
# create model
model = Sequential()
model.add(Dense(200, input_dim=Xm.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
# Compile model
sgd = SGD(lr=0.01)
model.compile(optimizer=sgd, loss='mean_squared_error')
# Fit the model
model.fit(Xm[train], ym[train], nb_epoch=300, verbose=0)
# evaluate the model
scores = model.evaluate(Xm[val], ym[val])
cvscores.append(scores)
mse_cv = np.mean(cvscores)
mse_cv
0.26704990021492303
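Beyond the mean, the spread of the fold errors is also informative. A one-line sketch, assuming `cvscores` still holds the five per-fold MSEs:
# dispersion of the 5-fold MSE estimates (sketch)
print('CV MSE: {0:.4f} +/- {1:.4f}'.format(np.mean(cvscores), np.std(cvscores)))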
kfold = KFold(len(Xm), 10)
cvscores = []
for i, (train, val) in enumerate(kfold):
# create model
model = Sequential()
model.add(Dense(200, input_dim=Xm.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
# Compile model
sgd = SGD(lr=0.01)
model.compile(optimizer=sgd, loss='mean_squared_error')
# Fit the model
model.fit(Xm[train], ym[train], nb_epoch=300, verbose=0)
# evaluate the model
scores = model.evaluate(Xm[val], ym[val])
cvscores.append(scores)
mse_cv = np.mean(cvscores)
mse_cv
0.27104718351053075
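Since the 5-fold and 10-fold experiments differ only in the number of folds, the loop could be wrapped in a helper. A sketch under the same assumptions (old `sklearn.cross_validation.KFold` API and the 200-unit sigmoid network used above):
def cv_mse(Xm, ym, n_folds, lr=0.01, nb_epoch=300):
    """Mean cross-validated MSE of the 200-unit sigmoid network (sketch)."""
    scores = []
    for train, val in KFold(len(Xm), n_folds):
        model = Sequential()
        model.add(Dense(200, input_dim=Xm.shape[1], init='uniform'))
        model.add(Activation('sigmoid'))
        model.add(Dense(1, init='uniform'))
        model.add(Activation('linear'))
        model.compile(optimizer=SGD(lr=lr), loss='mean_squared_error')
        model.fit(Xm[train], ym[train], nb_epoch=nb_epoch, verbose=0)
        scores.append(model.evaluate(Xm[val], ym[val], verbose=0))
    return np.mean(scores)
# e.g. cv_mse(Xm, ym, 5) and cv_mse(Xm, ym, 10) reproduce the two experiments above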
n_decay = 10
learning_decay = np.logspace(-6, 0, n_decay)
hist_list = list()
for i in range(n_decay):
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
model.compile(optimizer=SGD(lr=0.01, decay=learning_decay[i]), loss='mean_squared_error')
# training the network
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300,
verbose=0, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
# storing the results
hist_list.append(hist)
# making training and validation error plots
for i,hist in enumerate(hist_list):
title = 'Mean Squared Training and Validation Error. Decay factor: {0}'.format(learning_decay[i])
history_plot(hist, title)
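For reference, the `decay` parameter of the Keras SGD optimizer shrinks the effective learning rate as training progresses, roughly as lr_t = lr0 / (1 + decay * t), with t the update count. A small illustrative sketch of that schedule (not Keras internals, just the approximate formula):
# sketch: approximate effective learning rate after some number of updates
lr0 = 0.01
for decay in (1e-6, 1e-3, 1e-1):
    lr_after = lr0 / (1. + decay * 300.)   # e.g. after ~300 updates
    print('decay={0:g}: lr after 300 updates ~ {1:.6f}'.format(decay, lr_after))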
n_momentum = 20
momentum = np.linspace(0., 1., n_momentum+1)[1::]
hist_list = list()
for i in range(n_momentum):
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
model.compile(optimizer=SGD(lr=0.01, momentum=momentum[i]), loss='mean_squared_error')
# training the network
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300,
verbose=0, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
# storing the results
hist_list.append(hist)
# making training and validation error plots
for i,hist in enumerate(hist_list):
title = 'Mean Squared Training and Validation Error. Momentum: {0}'.format(momentum[i])
history_plot(hist, title)
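As a reminder, classical momentum keeps a velocity term that accumulates past gradients: roughly v <- momentum * v - lr * grad and w <- w + v. A toy numeric sketch (illustrative only, not the optimizer's actual code):
# toy sketch of the classical momentum update with a constant gradient
lr, mom = 0.01, 0.9
w, v = 0.0, 0.0
for grad in [1.0, 1.0, 1.0]:
    v = mom * v - lr * grad
    w = w + v
    print('v={0:.4f}, w={1:.4f}'.format(v, w))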
n_batches = 20
batch_sizes = np.round(np.linspace(1,X_train_scaled.shape[0],n_batches))
print(batch_sizes)
[ 1. 21. 41. 61. 81. 100. 120. 140. 160. 180. 200. 220. 240. 260. 280. 299. 319. 339. 359. 379.]
hist_list = list()
for i in range(n_batches):
# building the model
model = Sequential()
model.add(Dense(output_dim=200, input_dim=X_train_scaled.shape[1], init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(1, init='uniform'))
model.add(Activation('linear'))
model.compile(optimizer=SGD(lr=0.01), loss='mean_squared_error')
# training the network
hist = model.fit(X_train_scaled.as_matrix(), y_train_scaled.as_matrix(), nb_epoch=300, batch_size=int(batch_sizes[i]),  # cast to int: np.round returns floats
verbose=0, validation_data=(X_test_scaled.as_matrix(), y_test_scaled.as_matrix()))
# storing the results
hist_list.append(hist)
# making training and validation error plots
for i,hist in enumerate(hist_list):
title = 'Mean Squared Training and Validation Error. Batch Size: {0}'.format(batch_sizes[i])
history_plot(hist, title)
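The batch size directly controls how many weight updates happen per epoch, roughly ceil(n_samples / batch_size). A short sketch, assuming `batch_sizes` as computed above:
# sketch: number of gradient updates per epoch for each tested batch size
n_samples = X_train_scaled.shape[0]
for bs in batch_sizes:
    updates = int(np.ceil(float(n_samples) / bs))
    print('batch_size={0:.0f}: ~{1} updates per epoch'.format(bs, updates))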